{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "text_classification_solution.ipynb",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7H3yTncQfoym"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "pp-UaomMQJNo"
      },
      "source": [
        "# Multiclass text classification solution\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/examples/blob/master/community/en/text_classification_solution.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/examples/tree/master/community/en/text_classification_solution.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "h1CiDh7CfqON",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NBhbug0Qgmuu"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This notebook contains a solution to the multiclass classification exercise from the [Basic text classification](https://www.tensorflow.org/tutorials/keras/text_classification) tutorial on [tensorflow.org](https://www.tensorflow.org/tutorials). This notebook contains few comments. To learn more, please refer to the tutorial."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "8RZOuS9LWQvv",
        "tags": [],
        "colab": {}
      },
      "source": [
        "!pip install tf-nightly"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "2ew7HTbPpCJH",
        "tags": [],
        "colab": {}
      },
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "from tensorflow.keras import layers\n",
        "from tensorflow.keras import losses\n",
        "from tensorflow.keras import preprocessing\n",
        "from tensorflow.keras.layers.experimental.preprocessing import TextVectorization"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Jw_wAEzInaL1",
        "colab": {}
      },
      "source": [
        "print(tf.__version__)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Jxfg2FcfsD8f"
      },
      "source": [
        "### Download the Stack Overflow dataset (sourced from BigQuery)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "k8eRO0ng5r1A",
        "tags": [],
        "colab": {}
      },
      "source": [
        "url = \"http://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz\"\n",
        "\n",
        "tf.keras.utils.get_file(\"stack_overflow_16k.tar.gz\", url,\n",
        "                        untar=True, cache_dir='.',\n",
        "                        cache_subdir='')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "lkqtzHOb51ff",
        "tags": [],
        "colab": {}
      },
      "source": [
        "batch_size = 32\n",
        "\n",
        "raw_train_ds = tf.keras.preprocessing.text_dataset_from_directory(\n",
        "    'train', batch_size=batch_size, validation_split=0.2, subset='training', seed=42)\n",
        "\n",
        "raw_val_ds = tf.keras.preprocessing.text_dataset_from_directory(\n",
        "    'train', batch_size=batch_size, validation_split=0.2, subset='validation', seed=42)\n",
        "\n",
        "raw_test_ds = tf.keras.preprocessing.text_dataset_from_directory(\n",
        "    'test', batch_size=batch_size)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yRLs33z1tUi_"
      },
      "source": [
        "### Explore the data"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "wtOsEKAe32ji",
        "colab": {}
      },
      "source": [
        "print(raw_train_ds.class_names)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "vQlq9CmP6tJX",
        "tags": [],
        "colab": {}
      },
      "source": [
        "for text_batch, label_batch in raw_train_ds.take(1):\n",
        "  for i in range(5):\n",
        "    print(text_batch.numpy()[i])\n",
        "    print(label_batch.numpy()[i])\n",
        "    print()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qmXSZZBNt8aU"
      },
      "source": [
        "### Prepare data for training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "VKVrmuB-7fAm",
        "colab": {}
      },
      "source": [
        "max_features = 5000\n",
        "embedding_dim = 128\n",
        "sequence_length = 500\n",
        "\n",
        "vectorize_layer = TextVectorization(\n",
        "    max_tokens=max_features,\n",
        "    output_mode='int',\n",
        "    output_sequence_length=sequence_length)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "aqO9ZzPV7kOL",
        "colab": {}
      },
      "source": [
        "# Make a text-only dataset (no labels) and call adapt\n",
        "text_ds = raw_train_ds.map(lambda x, y: x)\n",
        "vectorize_layer.adapt(text_ds)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "fX6zq78yu3Fv"
      },
      "source": [
        "### Vectorize the data\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "H3gXSDzt7qni",
        "colab": {}
      },
      "source": [
        "def vectorize_text(text, label):\n",
        "  text = tf.expand_dims(text, -1)\n",
        "  return vectorize_layer(text), label\n",
        "\n",
        "train_ds = raw_train_ds.map(vectorize_text)\n",
        "val_ds = raw_val_ds.map(vectorize_text)\n",
        "test_ds = raw_test_ds.map(vectorize_text)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "nNHfxyfd2v-l"
      },
      "source": [
        "### Configure the dataset for performance"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "DBHgkpQuqfaj",
        "colab": {}
      },
      "source": [
        "AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
        "\n",
        "train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)\n",
        "val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)\n",
        "test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "evYhQerxvAq9"
      },
      "source": [
        "### Build the model\n",
        "\n",
        "This model has not been tuned in any way; the goal is to show you the mechanics."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "FMQEUVGM74iQ",
        "colab": {}
      },
      "source": [
        "model = tf.keras.Sequential([\n",
        "  layers.Embedding(max_features + 1, embedding_dim),\n",
        "  layers.Dropout(0.2),\n",
        "  layers.GlobalAveragePooling1D(),\n",
        "  layers.Dropout(0.2),\n",
        "  layers.Dense(4)])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "XJAZP7YS036K",
        "colab": {}
      },
      "source": [
        "model.compile(\n",
        "    loss=losses.SparseCategoricalCrossentropy(from_logits=True), \n",
        "    optimizer='adam', \n",
        "    metrics=['accuracy'])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_AHy6MOCvbm9"
      },
      "source": [
        "### Train the model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Gs78f_N58MO9",
        "tags": [],
        "colab": {}
      },
      "source": [
        "history = model.fit(\n",
        "    train_ds,\n",
        "    validation_data=val_ds,\n",
        "    epochs=5)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Hm8g4isKvebz"
      },
      "source": [
        "### Evaluate the model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "WbGYQ8sERUT7",
        "colab": {}
      },
      "source": [
        "loss, accuracy = model.evaluate(test_ds)\n",
        "\n",
        "print(\"Loss: \", loss)\n",
        "print(\"Accuracy: \", accuracy)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
